package drds.plus.sql_process.optimizer.chooser.join;

import drds.plus.sql_process.abstract_syntax_tree.expression.item.Item;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.function.BooleanFilter;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.function.Filter;
import drds.plus.sql_process.abstract_syntax_tree.expression.order_by.OrderBy;
import drds.plus.sql_process.abstract_syntax_tree.node.Node;
import drds.plus.sql_process.abstract_syntax_tree.node.query.$Query$;
import drds.plus.sql_process.abstract_syntax_tree.node.query.Join;
import drds.plus.sql_process.abstract_syntax_tree.node.query.Query;
import drds.plus.sql_process.abstract_syntax_tree.node.query.TableQuery;

import java.util.*;

/**
 * Enumerates candidate join orders for the inner-join part of a query tree.
 *
 * <p>The constructor flattens the given query into a list of sub-queries whose
 * internal join order is left alone, records every join condition found in the
 * tree, and hands the flattened list to a {@link PermutationGenerator}. Each
 * call to {@link #getNext()} then turns the next usable permutation into a
 * join tree that carries the select items, filters, ordering, grouping and
 * limits captured from the original query.
 */
public final class InnerJoinPermutationGenerator {

    // Attributes captured from the original root query; re-applied to every generated join tree.
    private final List<Item> itemList;
    private final Filter resultFilter;
    private final Filter allWhereFilter;
    private final Filter subQueryFilter;
    private final List<OrderBy> orderByList;
    private final List<OrderBy> groupByList;
    private final Comparable limitFrom;
    private final Comparable limitTo;

    /** Flattened sub-queries that are permuted against each other. */
    private final List<Query> queryList = new ArrayList<Query>();
    /**
     * Join conditions indexed in both directions:
     * {@code A join B on A.id = B.id} becomes
     * {@code {A.id -> {B.id -> BF[A.id=B.id]}, B.id -> {A.id -> BF[B.id=A.id]}}}.
     */
    private final Map<Item, Map<Item, BooleanFilter>> leftJoinColumnToRightJoinColumnToFilterMapMap =
            new HashMap<Item, Map<Item, BooleanFilter>>();
    private final PermutationGenerator permutationGenerator;

    /**
     * Captures the root query's attributes, flattens its tree into the list of
     * permutable sub-queries and indexes all join conditions.
     *
     * @param query root of the query tree whose inner joins should be reordered
     */
    public InnerJoinPermutationGenerator(Query query) {
        this.itemList = query.getSelectItemList();
        this.resultFilter = query.getResultFilter();
        this.allWhereFilter = query.getAllWhereFilter();
        this.subQueryFilter = query.getSubQueryFunctionFilter();
        this.orderByList = query.getOrderByList();
        this.groupByList = query.getGroupByList();
        this.limitFrom = query.getLimitFrom();
        this.limitTo = query.getLimitTo();

        // The generator is built from the flattened list, so flatten first.
        this.splitToQueryNodeList(query);
        this.permutationGenerator = new PermutationGenerator(this.queryList);
        addJoinFilterMapping(query);
    }

    /**
     * Collects from the query tree the sub-trees whose join order will not be
     * adjusted any further.
     *
     * <pre>
     * These are:
     * 1. Join nodes that are not inner joins. The user explicitly asked for a
     *    left/right join and thereby fixed the outer side as the driving
     *    table, so no further adjustment is made.
     * 2. Join nodes marked as sub-queries; nodes inside a sub-query are
     *    adjusted separately.
     * 3. Join nodes that carry an "other join on" filter.
     * 4. TableQuery and $Query$ nodes.
     * </pre>
     *
     * Any other node is not added itself; its children are visited recursively.
     *
     * @param query node to inspect
     */
    private void splitToQueryNodeList(Query query) {
        if (query instanceof Join) {
            Join join = (Join) query;
            boolean keepAsUnit = !join.isInnerJoin()
                    || query.isSubQuery()
                    || query.getOtherJoinOnFilter() != null;
            if (keepAsUnit) {
                this.queryList.add(query);
                return;
            }
        }
        if (query instanceof TableQuery || query instanceof $Query$) {
            this.queryList.add(query);
            return;
        }
        for (Node node : query.getNodeList()) {
            splitToQueryNodeList((Query) node);
        }
    }

    /**
     * Walks the whole tree below {@code query} and indexes the join filters of
     * every Join node encountered.
     *
     * @param query node to inspect
     */
    private void addJoinFilterMapping(Query query) {
        if (query instanceof Join) {
            Join join = (Join) query;
            List<Item> leftJoinItemList = join.getLeftJoinItemList();
            List<Item> rightJoinItemList = join.getRightJoinItemList();
            assert (leftJoinItemList.size() == rightJoinItemList.size());
            for (Filter filter : join.getJoinFilterList()) {
                addJoinFilterMapping((BooleanFilter) filter);
            }
        }
        for (Node node : query.getNodeList()) {
            addJoinFilterMapping((Query) node);
        }
    }

    /**
     * Registers a single join condition under both of its operands.
     *
     * @param booleanFilter join condition whose column and value are both items
     */
    private void addJoinFilterMapping(BooleanFilter booleanFilter) {
        Item leftItem = (Item) booleanFilter.getColumn();
        Item rightItem = (Item) booleanFilter.getValue();

        // left -> right keeps the filter as written.
        filterMapFor(leftItem).put(rightItem, booleanFilter);
        // right -> left stores a swapped copy so that the stored filter always
        // reads "key column on the left, looked-up column on the right".
        filterMapFor(rightItem).put(leftItem, swapJoinColumn(booleanFilter));
    }

    /** Returns the inner map registered for {@code item}, creating it on first use. */
    private Map<Item, BooleanFilter> filterMapFor(Item item) {
        Map<Item, BooleanFilter> filterMap = this.leftJoinColumnToRightJoinColumnToFilterMapMap.get(item);
        if (filterMap == null) {
            filterMap = new HashMap<Item, BooleanFilter>();
            this.leftJoinColumnToRightJoinColumnToFilterMapMap.put(item, filterMap);
        }
        return filterMap;
    }

    /**
     * Returns a copy of {@code booleanFilter} with its column and value exchanged.
     * The argument itself is not modified.
     */
    private BooleanFilter swapJoinColumn(BooleanFilter booleanFilter) {
        BooleanFilter copy = booleanFilter.copy();
        copy.setColumn(booleanFilter.getValue());
        copy.setValue(booleanFilter.getColumn());
        return copy;
    }

    /**
     * Produces the next join tree.
     *
     * <p>Permutations for which {@link #join(List)} yields no tree are skipped.
     *
     * @return the next join tree with the original query's attributes applied,
     *         or {@code null} once the permutations are exhausted
     */
    public Query getNext() {
        while (permutationGenerator.hasNext()) {
            List<Query> permutation = permutationGenerator.next();
            // Work on deep copies held in a fresh list: the list handed out by
            // the generator is left untouched.
            List<Query> copiedList = new ArrayList<Query>(permutation.size());
            for (Query node : permutation) {
                copiedList.add(node.deepCopy());
            }
            Query query = join(copiedList);
            if (query != null) {
                query.setSelectItemListAndSetNeedBuild(this.itemList);
                query.setResultFilterAndSetNeedBuild(this.resultFilter);
                query.setAllWhereFilter(allWhereFilter);
                query.setSubQueryFunctionFilter(subQueryFilter);
                query.setOrderByListAndSetNeedBuild(orderByList);
                query.setGroupByListAndSetNeedBuild(groupByList);
                query.setLimitFrom(limitFrom);
                query.setLimitTo(limitTo);
                return query;
            }
        }
        return null;
    }

    /**
     * Chains the given nodes into a join tree, joining each node onto the
     * result of the joins before it.
     *
     * @param queryList nodes in the order they should be joined
     * @return the single node if there is only one; the resulting join
     *         otherwise; {@code null} if the first two nodes share no join
     *         condition or the list is empty
     */
    private Query join(List<Query> queryList) {
        if (queryList.size() == 1) {
            return queryList.get(0);
        }
        Join joinNode = null;
        for (int i = 1; i < queryList.size(); i++) {
            boolean firstPair = (joinNode == null);
            Query leftNode = firstPair ? queryList.get(i - 1) : joinNode;
            Query rightNode = queryList.get(i);

            List<BooleanFilter> joinFilterList = getJoinFilterList(leftNode, rightNode);
            // NOTE(review): only the first pair is rejected when no join
            // condition connects it; later joins are accepted even with an
            // empty filter list. Kept as-is - confirm whether this asymmetry
            // is intended.
            if (joinFilterList == null || (firstPair && joinFilterList.isEmpty())) {
                return null;
            }
            joinNode = leftNode.join(rightNode);
            joinNode.setJoinFilterList(joinFilterList);
            joinNode.build();
        }
        return joinNode;
    }

    /**
     * Finds the recorded join conditions that connect the two nodes.
     *
     * <p>A condition is included when its left column is among the select
     * items of {@code leftNode} and its right column is among the select items
     * of {@code rightNode}.
     *
     * @return the matching conditions; empty (never {@code null}) if there are none
     */
    private List<BooleanFilter> getJoinFilterList(Query leftNode, Query rightNode) {
        List<BooleanFilter> filterList = new LinkedList<BooleanFilter>();
        // Fetched lazily and at most once: only needed when some left item has join conditions.
        List<Item> rightItemList = null;
        for (Item leftItem : leftNode.copySelectItemList()) {
            Map<Item, BooleanFilter> rightJoinColumnToFilterMap =
                    this.leftJoinColumnToRightJoinColumnToFilterMapMap.get(leftItem);
            if (rightJoinColumnToFilterMap == null) {
                continue;
            }
            if (rightItemList == null) {
                rightItemList = rightNode.copySelectItemList();
            }
            for (Map.Entry<Item, BooleanFilter> entry : rightJoinColumnToFilterMap.entrySet()) {
                // Add the condition if its join column is among the right node's select items.
                if (rightItemList.contains(entry.getKey())) {
                    filterList.add(entry.getValue());
                }
            }
        }
        return filterList;
    }
}
